import java.util.*;

/**
 * Picks an index at random with probability proportional to its weight
 * (LeetCode 528, "Random Pick with Weight").
 *
 * <p>The constructor builds running (prefix) sums of the weights. Each call to
 * {@link #pickIndex()} draws a uniform integer in {@code [1, total]} and returns the
 * first index whose prefix sum is at least that draw. Index {@code i} therefore owns
 * exactly {@code w[i]} of the {@code total} equally likely draws, and zero-weight
 * indices are never returned.
 */
public class Solution528 {
    /** {@code prefixArr[i] = w[0] + ... + w[i]}; non-decreasing because weights are non-negative. */
    int[] prefixArr;
    /** Sum of all weights (equal to the last element of {@code prefixArr}); always positive. */
    int total = 0;
    /** Source of randomness used by {@link #pickIndex()}. */
    Random random;

    /**
     * Creates a picker over the given weights using a fresh {@link Random}.
     *
     * @param w non-negative weights, at least one of them positive, summing to at most
     *          {@link Integer#MAX_VALUE}
     * @throws NullPointerException     if {@code w} is null
     * @throws IllegalArgumentException if a weight is negative, all weights are zero
     *                                  (or {@code w} is empty), or the sum overflows an int
     */
    public Solution528(int[] w) {
        this(w, new Random());
    }

    /**
     * Creates a picker over the given weights using the supplied random source
     * (pass a seeded {@link Random} for reproducible picks).
     *
     * @param w      non-negative weights, at least one of them positive, summing to at most
     *               {@link Integer#MAX_VALUE}
     * @param random source of randomness for {@link #pickIndex()}
     * @throws NullPointerException     if {@code w} or {@code random} is null
     * @throws IllegalArgumentException if a weight is negative, all weights are zero
     *                                  (or {@code w} is empty), or the sum overflows an int
     */
    public Solution528(int[] w, Random random) {
        Objects.requireNonNull(w, "w");
        this.random = Objects.requireNonNull(random, "random");
        prefixArr = new int[w.length];
        // Accumulate in a long so an oversized total is detected instead of wrapping.
        long sum = 0;
        for (int i = 0; i < w.length; i++) {
            if (w[i] < 0) {
                throw new IllegalArgumentException(
                        "weight at index " + i + " is negative: " + w[i]);
            }
            sum += w[i];
            if (sum > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("sum of weights exceeds Integer.MAX_VALUE");
            }
            prefixArr[i] = (int) sum;
        }
        // Random.nextInt(total) needs a positive bound; fail here rather than on first pick.
        if (sum == 0) {
            throw new IllegalArgumentException("weights must contain at least one positive value");
        }
        total = (int) sum;
    }

    /**
     * Returns the smallest index whose prefix sum is {@code >= target}.
     *
     * <p>This is a lower-bound search with no early exit on equality. With zero weights the
     * prefix sums contain duplicates, and stopping at an arbitrary equal element could
     * return a zero-weight index.
     *
     * @param target a value in {@code [1, total]}, so a valid index always exists
     * @return the first index {@code i} with {@code prefixArr[i] >= target}
     */
    private int binarySearch(int target) {
        int lo = 0;
        int hi = prefixArr.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1; // overflow-safe midpoint
            if (prefixArr[mid] < target) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /**
     * Returns an index chosen with probability {@code w[i] / total}.
     *
     * @return an index into the original weight array whose weight is positive
     */
    public int pickIndex() {
        // nextInt(total) is uniform in [0, total); shift to [1, total] to match prefix sums.
        return binarySearch(random.nextInt(total) + 1);
    }

    /** Samples many picks and prints the per-index counts (expected ratio 1:2:3). */
    public static void main(String[] args) {
        Solution528 s = new Solution528(new int[]{1, 2, 3});
        Map<Integer, Integer> counts = new TreeMap<>(); // sorted by index for readable output
        for (int i = 0; i < 6000000; i++) {
            counts.merge(s.pickIndex(), 1, Integer::sum);
        }
        System.out.println(counts);
    }
}
